"""Pytorch Dataset object that loads MNIST and SVHN. It returns x,y,s where s=0 when x,y is taken from MNIST."""

import os
import numpy as np
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms

# code modified from https://github.com/atuannguyen/DIRT/blob/main/domain_gen_rotatedmnist/mnist_loader.py
class MnistRotated(data_utils.Dataset):
    """Rotated-MNIST dataset for domain generalization.

    A set of MNIST digits is rotated by each requested angle; every angle
    forms one domain. In training mode (``train=True``) all angles in
    ``list_train_domains`` are stacked and shuffled, and items are
    ``(x, y, d)`` where ``d`` is the position of the item's angle in
    ``list_train_domains``. In test mode only ``list_test_domain[0]`` is used
    and items are ``(x, y)``.

    Args:
        list_train_domains: rotation angles (degrees) of the training domains,
            e.g. ``['0', '15', '30']``; any value accepted by ``int()`` works.
        list_test_domain: sequence whose first element is the test angle.
        root: directory holding the MNIST data and the
            ``rotatedmnist/supervised_inds_*.npy`` index files.
        mnist_subset: index-file id used when ``all_data`` is False;
            ``'full'`` (or any falsy value) concatenates the files 0-9.
        transform: optional callable applied to each image in ``__getitem__``.
        train: build the training domains (True) or the test domain (False).
        download: forwarded to ``torchvision.datasets.MNIST``.
        all_data: use every image of the loaded MNIST split instead of the
            stored subset indices.
        num_supervised: if set (non-zero), keep only the first
            ``num_supervised`` selected images; afterwards this attribute holds
            the number of images actually used per domain.
        not_eval: load the MNIST training split (True) or the MNIST test split
            (False, which also forces ``all_data``).
    """

    def __init__(self,
                 list_train_domains,
                 list_test_domain,
                 root,
                 mnist_subset=False,
                 transform=None,
                 train=True,
                 download=True,
                 all_data=False,
                 num_supervised=None,
                 not_eval=True):

        self.list_train_domains = list_train_domains
        self.list_test_domain = list_test_domain
        if mnist_subset == 'full':
            print('use full subset of MNIST!')
            self.mnist_subset = None
        else:
            self.mnist_subset = mnist_subset
            print(f'use subset {self.mnist_subset}')
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.train = train
        self.download = download
        self.all_data = all_data
        self.num_supervised = num_supervised
        if not not_eval:
            self.all_data = True  # always use all data from the MNIST test split
        self.not_eval = not_eval  # True: MNIST train split, False: MNIST test split

        if self.train:
            self.train_data, self.train_labels, self.train_domain, self.train_angles = self._get_data()
        else:
            self.test_data, self.test_labels, self.test_domain = self._get_data()

    def load_inds(self):
        """Return the stored MNIST sample indices to use, as an int64 array.

        With a falsy ``mnist_subset`` the index files ``supervised_inds_0``
        to ``supervised_inds_9`` are concatenated in order; otherwise only
        ``supervised_inds_<mnist_subset>`` is loaded. Each file presumably
        holds 1000 class-balanced indices (100 per digit) — TODO confirm
        against the files on disk.

        The result is cast to int64 so it can index torch tensors directly.
        """
        inds_dir = os.path.join(self.root, 'rotatedmnist')
        if not self.mnist_subset:
            parts = [np.load(os.path.join(inds_dir, f'supervised_inds_{i}.npy'))
                     for i in range(10)]
            inds = np.concatenate(parts)
        else:
            inds = np.load(os.path.join(inds_dir, f'supervised_inds_{self.mnist_subset}.npy'))
        return np.asarray(inds, dtype=np.int64)

    def _load_mnist(self):
        """Load the selected MNIST images/labels and update ``num_supervised``.

        Returns ``(imgs, labels)`` as produced by ``transforms.ToTensor()``,
        restricted to ``load_inds()`` unless ``all_data`` and truncated to
        ``num_supervised`` when that is set.
        """
        if not self.not_eval:
            self.all_data = True
        dataset = datasets.MNIST(self.root,
                                 train=self.not_eval,
                                 download=self.download,
                                 transform=transforms.ToTensor())
        # One batch holding the whole split.
        loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)
        mnist_imgs, mnist_labels = next(iter(loader))

        if not self.all_data:
            sup_inds = torch.from_numpy(self.load_inds())
            mnist_imgs = mnist_imgs[sup_inds]
            mnist_labels = mnist_labels[sup_inds]

        # Keep num_supervised examples so images, labels and domain ids stay aligned.
        if self.num_supervised:
            mnist_imgs = mnist_imgs[:self.num_supervised]
            mnist_labels = mnist_labels[:self.num_supervised]
        self.num_supervised = int(mnist_imgs.shape[0])
        return mnist_imgs, mnist_labels

    @staticmethod
    def _rotate_all(mnist_imgs, angle):
        """Rotate every image by ``angle`` degrees; returns a stacked float tensor."""
        if len(mnist_imgs) == 0:
            return torch.zeros((0, 1, 28, 28))
        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()
        rotated = []
        for img in mnist_imgs:
            pil_img = to_pil(img)
            if angle != 0:
                pil_img = transforms.functional.rotate(pil_img, angle)
            rotated.append(to_tensor(pil_img))
        return torch.stack(rotated)

    @staticmethod
    def _domain_labels(num_domains, per_domain):
        """Domain ids ``[0]*per_domain + [1]*per_domain + ...`` as int64."""
        return torch.arange(num_domains).repeat_interleave(per_domain)

    def _get_data(self):
        """Build the rotated tensors for the current mode.

        Returns:
            train mode: ``(imgs, labels, domains, angles)``. ``imgs``,
            ``labels`` and ``domains`` are shuffled in lockstep with NumPy's
            global RNG. ``angles`` lists the integer angle of each domain id.
            test mode: ``(imgs, labels, domains)`` with every domain id 0.

        Raises:
            ValueError: in train mode, if ``list_train_domains`` is empty
                (``int()`` errors for unparseable angles also propagate).
        """
        if self.train:
            # Validate the angles before the expensive MNIST load.
            angles = [int(domain) for domain in self.list_train_domains]
            if not angles:
                raise ValueError('list_train_domains must contain at least one rotation angle')
            mnist_imgs, mnist_labels = self._load_mnist()

            # Rotate each distinct angle only once.
            rotated = {}
            for angle in angles:
                if angle not in rotated:
                    rotated[angle] = self._rotate_all(mnist_imgs, angle)

            # Before shuffling, domain i occupies rows [i*N, (i+1)*N).
            train_imgs = torch.cat([rotated[angle] for angle in angles])
            train_labels = mnist_labels.repeat(len(angles))
            train_domains = self._domain_labels(len(angles), self.num_supervised)

            # Shuffle everything one more time.
            inds = np.arange(train_labels.size()[0])
            np.random.shuffle(inds)
            inds = torch.from_numpy(inds)
            return train_imgs[inds], train_labels[inds], train_domains[inds], angles

        mnist_imgs, mnist_labels = self._load_mnist()
        rot_angle = int(self.list_test_domain[0])
        test_imgs = self._rotate_all(mnist_imgs, rot_angle)
        test_domain = torch.zeros(mnist_labels.size()).long()
        return test_imgs, mnist_labels, test_domain

    def __len__(self):
        if self.train:
            return len(self.train_labels)
        return len(self.test_labels)

    def __getitem__(self, index):
        """Return ``(x, y, d)`` in training mode and ``(x, y)`` in test mode."""
        if self.train:
            x = self.train_data[index]
            y = self.train_labels[index]
            d = self.train_domain[index]
            if self.transform is not None:
                x = self.transform(x)
            return x, y, d

        x = self.test_data[index]
        y = self.test_labels[index]
        if self.transform is not None:
            x = self.transform(x)
        return x, y


